142367
@@ -34,6 +34,7 @@
 import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.type.InferTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql.type.SqlTypeUtil;
@@ -41,10 +42,12 @@
 import org.apache.calcite.util.CompositeList;
 import org.apache.calcite.util.ImmutableIntList;
 import org.apache.calcite.util.Util;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
 import org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction;
 import org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlSumAggFunction;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
+import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
 
 import java.math.BigDecimal;
 import java.util.ArrayList;
@@ -280,13 +283,15 @@
private RexNode reduceAvg(
     final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
     final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
     final int iAvgInput = oldCall.getArgList().get(0);
-    RelDataType avgInputType = typeFactory.createTypeWithNullability(
+    final RelDataType avgInputType = typeFactory.createTypeWithNullability(
         getFieldType(oldAggRel.getInput(), iAvgInput), true);
+    final RelDataType sumReturnType = getSumReturnType(
+        rexBuilder.getTypeFactory(), avgInputType, oldCall.getType());
     final AggregateCall sumCall =
         AggregateCall.create(
             new HiveSqlSumAggFunction(
                 oldCall.isDistinct(),
-                oldCall.getAggregation().getReturnTypeInference(),
+                ReturnTypes.explicit(sumReturnType),
                 oldCall.getAggregation().getOperandTypeInference(),
                 oldCall.getAggregation().getOperandTypeChecker()), //SqlStdOperatorTable.SUM,
             oldCall.isDistinct(),
@@ -371,17 +376,21 @@
private RexNode reduceStddev(
     final RexNode argRef =
         rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), false);
     final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
+    final RelDataType sumReturnType = getSumReturnType(
+        rexBuilder.getTypeFactory(), argRef.getType(), oldCall.getType());
 
     final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY,
         argRef, argRef);
     final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
+    final RelDataType sumSquaredReturnType = getSumReturnType(
+        rexBuilder.getTypeFactory(), argSquared.getType(), oldCall.getType());
 
     final AggregateCall sumArgSquaredAggCall =
         createAggregateCallWithBinding(typeFactory,
             new HiveSqlSumAggFunction(
                 oldCall.isDistinct(),
-                oldCall.getAggregation().getReturnTypeInference(),
-                oldCall.getAggregation().getOperandTypeInference(),
+                ReturnTypes.explicit(sumSquaredReturnType),
+                InferTypes.explicit(Collections.singletonList(argSquared.getType())),
                 oldCall.getAggregation().getOperandTypeChecker()), //SqlStdOperatorTable.SUM,
             argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);
 
@@ -397,8 +406,8 @@
private RexNode reduceStddev(
         AggregateCall.create(
             new HiveSqlSumAggFunction(
                 oldCall.isDistinct(),
-                oldCall.getAggregation().getReturnTypeInference(),
-                oldCall.getAggregation().getOperandTypeInference(),
+                ReturnTypes.explicit(sumReturnType),
+                InferTypes.explicit(Collections.singletonList(argOrdinalType)),
                 oldCall.getAggregation().getOperandTypeChecker()), //SqlStdOperatorTable.SUM,
             oldCall.isDistinct(),
             oldCall.isApproximate(),
@@ -532,4 +541,25 @@
private RelDataType getFieldType(RelNode relNode, int i) {
         relNode.getRowType().getFieldList().get(i);
     return inputField.getType();
   }
+
+  /**
+   * Returns the type of the SUM calls that reduceAvg/reduceStddev generate for
+   * {@code inputType}: Hive long for integral inputs, Hive double for
+   * TIMESTAMP/FLOAT/DOUBLE/VARCHAR/CHAR inputs, else {@code originalReturnType}.
+   */
+  private RelDataType getSumReturnType(RelDataTypeFactory typeFactory,
+      RelDataType inputType, RelDataType originalReturnType) {
+    switch (inputType.getSqlTypeName()) {
+      case TINYINT: case SMALLINT: case INTEGER: case BIGINT:
+        return TypeConverter.convert(TypeInfoFactory.longTypeInfo, typeFactory);
+      case TIMESTAMP: case FLOAT: case DOUBLE: case VARCHAR: case CHAR:
+        return TypeConverter.convert(TypeInfoFactory.doubleTypeInfo, typeFactory);
+      case DECIMAL:
+        // We keep precision and scale
+      default:
+        // Never return null: callers pass this straight to ReturnTypes.explicit(...).
+        // Unlisted types keep the aggregate's original return type.
+        return originalReturnType;
+    }
+  }
 }
